In [2]:
torch.setdefaulttensortype('torch.FloatTensor')


Out[2]:


In [3]:
trainset = torch.load('cifar10-train.t7')
testset = torch.load('cifar10-test.t7')
classes = {'airplane', 'automobile', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck'}

In [4]:
print(trainset)


Out[4]:
{
  data : ByteTensor - size: 10000x3x32x32
  label : ByteTensor - size: 10000
}

In [5]:
print(testset)


Out[5]:
{
  data : ByteTensor - size: 10000x3x32x32
  label : ByteTensor - size: 10000
}

In [6]:
function tabulate(A, f)
  local idx = {}
  local ndims = A:dim()
  local dim = A:size()
  idx[ndims] = 0
  for i=1, (ndims - 1) do
    idx[i] = 1
  end
  return A:apply(function()
    for i=ndims, 0, -1 do
      idx[i] = idx[i] + 1
      if idx[i] <= dim[i] then
        break
      end
      idx[i] = 1
    end
    return f(unpack(idx))
  end)
end

In [17]:
NearestNeighbor = {}

function NearestNeighbor:new(o)
    o = nearestNeighbor or {}
    setmetatable(o, self)
    self.__index = self
    return o
end

function NearestNeighbor:train(trainset)
    NearestNeighbor.trainset = trainset or {}
end

function NearestNeighbor:predict(testset, is_l2, k)
    is_l2 = is_l2 or false
    k = k or 1
    
    neighbors = {}
    
    if is_l2 then
        print('l2', k)
    else
        print('l1', k)
    end
    
    -- for i=1, testset.data:size(1) do
    for i=1, 10 do
        local min_diff = 9999999
        local cur_neighbors = {}
        
        -- for j=1, trainset.data:size(1) do
        for j=1, 1000 do
            local diff = trainset.data[j] - testset.data[i]
            
            diff = torch.FloatTensor(diff:size()):copy(diff)
            if is_l2 then
                diff_sum = math.sqrt(torch.pow(diff,2):sum())
            else
                diff_sum = torch.abs(diff):sum()
            end
            
            table.insert(cur_neighbors, {j, diff_sum})
        end
        
        table.sort(cur_neighbors, function(a,b) return a[2] < b[2] end)
        -- print(cur_neighbors)
        
        local counter, max, answer = {}, 0, nil
        
        for i=1,k do
            local j = cur_neighbors[i][1]
            local label = trainset.label[j]
            
            if counter[label] == nil then
                counter[label] = 1
            else
                counter[label] = counter[label] + 1
            end
            
            if max < counter[label] then
                max = counter[label]
                answer = label
            end
        end
        -- print(counter, answer)
        -- print("============")
        
        table.insert(neighbors, answer)
    end
    
    return neighbors
end

nn = NearestNeighbor:new()
nn:train(trainset)

t = os.clock()
results1 = nn:predict(testset, false)
print("Original time:", os.difftime(os.clock(), t))

t = os.clock()
results2 = nn:predict(testset, true)
print("Apply method:", os.difftime(os.clock(), t))

t = os.clock()
results3 = nn:predict(testset, false, 5)
print("Original time:", os.difftime(os.clock(), t))

t = os.clock()
results4 = nn:predict(testset, true, 5)
print("Apply method:", os.difftime(os.clock(), t))

function compare(results, testset)
    count = 0
    
    for idx, l in ipairs(results) do
        -- print(results[idx], testset.label[idx])
        if results[idx] == testset.label[idx] then
            count = count + 1
        end
    end
    
    return count/#results*100
end

print(string.format('accuracy: %.2f', compare(results1, testset)))
print(string.format('accuracy: %.2f', compare(results2, testset)))
print(string.format('accuracy: %.2f', compare(results3, testset)))
print(string.format('accuracy: %.2f', compare(results4, testset)))


Out[17]:
l1	1	
Original time:	1	
l2	1	
Apply method:	2	
l1	5	
Original time:	0	
l2	5	
Apply method:	2	
accuracy: 40.00	
accuracy: 40.00	
accuracy: 30.00	
accuracy: 20.00	

In [ ]: